# setup and maps ----------
# Load the 50-precinct Florida cluster, keep the population, VAP, and
# `obama`/`mccain` vote columns, and wrap it as a 4-district redist_map with
# a 10% population tolerance.
# NOTE(review): suppressWarnings() hides every warning raised while reading
# the file and building the map; narrow it once the expected warning is known.
fl50 = read_rds(here("data/fl50_cluster.rds")) %>%
    select(pop, vap, pop_black=BlackPop, pop_hisp=HispPop,
           vap_black=BlackVAP, vap_hisp=HispVAP, obama, mccain, geometry) %>%
    redist_map(pop_tol=0.1, ndists=4) %>%
    suppressWarnings()

# Save the precinct map. The plot is passed explicitly rather than picked up
# from last_plot(), and the path is resolved with here() like every other
# input/output in this script, so it does not depend on the working directory.
p_map = plot(fl50) + coord_sf(expand=FALSE) +
    theme(plot.margin=unit(rep(0, 4), "cm"))
ggsave(here("figures/fl50_map.pdf"), plot=p_map, width=3, height=3)


# enumerate all valid plans -----
# Obtain every 4-district plan of fl50 within the map's population bounds:
# run enumpart if no cached enumeration exists, otherwise read the cache.
path_enum = here("data/fl50_enum4.dat")
if (!file.exists(path_enum)) {
    # 112,515,494 solutions
    # showWarnings=FALSE: re-running the script must not warn when the
    # output directory already exists.
    dir.create(here("out/FL"), recursive=TRUE, showWarnings=FALSE)
    # NOTE(review): this branch writes its output under out/FL/, and nothing
    # here writes `path_enum`, so the cache read below is not populated by
    # this run -- confirm whether the result should be copied there.
    # Bounds taken as elements 1 and 3 of pop_bounds, presumably
    # (lower, target, upper) -- confirm.
    enum_raw = redist.enumpart(fl50$adj, here("out/FL/unordered"),
                               here("out/FL/ordered"), here("out/FL/enum4"),
                               ndists=4, all=TRUE,
                               weight_path=here("out/FL/pop"),
                               lower=attr(fl50, "pop_bounds")[1],
                               upper=attr(fl50, "pop_bounds")[3],
                               total_pop=fl50$pop)
} else {
    # The cache is a flat stream of labels; every nrow(fl50) consecutive
    # values form one plan (one matrix column). Labels are shifted by 1,
    # presumably because the cache stores 0-indexed districts.
    enum_raw = list(plans=1L+matrix(scan(path_enum), nrow=nrow(fl50)))
}

# Per-plan log spanning-tree quantity (by_plan() keeps one value per plan).
# Subtracting the max before exp() avoids overflow; w is then normalized
# into a probability vector over the enumerated plans.
log_st_enum = by_plan(comp_log_st(enum_raw$plans, fl50))
w = exp(log_st_enum - max(log_st_enum))
w = w / sum(w)

# One row per enumerated plan: population deviation, fraction of edges kept,
# and Republican (McCain) dissimilarity index. Row order follows plan order,
# which later code relies on when indexing enum_raw$plans by row.
pl_enum = redist_plans(enum_raw$plans, fl50, "enumpart", wgt=w) %>%
    mutate(dev = plan_parity(fl50),
           comp = comp_frac_kept(., fl50),
           seg = seg_dissim(., fl50, mccain, obama+mccain)) %>%
    as_tibble() %>%
    distinct(draw, dev, comp, seg)

# visualize Rep Dissim vs Compactness
# Density contours of the Republican dissimilarity index against the fraction
# of edges kept, over the distinct enumerated plans (no weight aesthetic, so
# every plan counts equally). Saved via here() with the plot passed explicitly.
p_seg = ggplot(pl_enum, aes(comp, seg)) +
    geom_density2d(color="#444444", linewidth=0.6, bins=24, adjust=1.1) +
    coord_cartesian(expand=FALSE) +
    labs(x="Fraction of edges kept", y="Republican dissimilarity index") +
    theme_bw(base_size=10, base_family="Times")
ggsave(here("figures/fl50_seg_rem.pdf"), plot=p_seg, width=3, height=3)



# run simulations ------
# Compactness strengths to validate and number of plans per run/chain.
comp_seq = seq(0.8, 1.2, by=0.1)
N = 1500

# Shared post-processing for both samplers: reduce a redist_plans object to
# one row per sampled plan with its fraction of edges kept (`comp`) and
# Republican dissimilarity (`seg`), then print the R-hat of `seg` across the
# runs/chains identified by the `chain` column.
#   pl:  redist_plans output of a sampler
#   rho: compactness value, used only in the printed message
# Returns a tibble with columns draw, chain, comp, seg.
summarize_sims = function(pl, rho) {
    out = pl %>%
        mutate(comp = comp_frac_kept(., fl50),
               seg = seg_dissim(., fl50, mccain, obama+mccain)) %>%
        as_tibble() %>%
        distinct(draw, chain, comp, seg)
    # Drop any carried-over plan matrix, presumably to keep results small.
    attr(out, "plans") = NULL
    cat("R-hat (rho =", rho, "):",
        redist:::diag_rhat(out$seg, out$chain), "\n")
    out
}

# SMC: 4 independent runs of N plans for each compactness value.
res_smc = map(comp_seq, function(rho) {
    redist_smc(fl50, N, compactness=rho, ncores=2, runs=4L,
               adapt_k_thresh=0.999, silent=TRUE) %>%
        summarize_sims(rho)
})

# Merge-split MCMC: 4 chains of N+500 iterations with warmup=500.
res_mcmc = map(comp_seq, function(rho) {
    redist_mergesplit_parallel(fl50, N+500, warmup=500, compactness=rho,
                               chains=4L, adapt_k_thresh=0.999,
                               init_name=FALSE, silent=TRUE) %>%
        summarize_sims(rho)
})

# plots ------

# Repeat each value of `x` equally often so the result has exactly `n_out`
# entries, then sort. This lets a sampler's draws be paired quantile by
# quantile with a sorted reference sample of size `n_out`. Errors when
# `n_out` is not a positive multiple of length(x), instead of producing
# columns of unequal length.
expand_sorted = function(x, n_out) {
    n_rep = if (length(x) > 0) n_out %/% length(x) else 0
    if (n_rep < 1 || n_rep * length(x) != n_out) {
        stop("cannot expand ", length(x), " draws to ", n_out, " values",
             call.=FALSE)
    }
    sort(rep(x, each=n_rep))
}

# One Q-Q style panel per compactness value: sorted sampler dissimilarities
# (SMC, MCMC) against sorted exact draws from the enumeration.
plots = imap(comp_seq, function(rho, i) {
    set.seed(5118)
    n_true = 100*N
    # Exact reference sample: enumerated plans drawn with probability
    # proportional to w^rho.
    rs_idx = sample(length(w), n_true, replace=TRUE, prob=w^rho)
    # plotmath title: the symbol rho followed by " = <value>"
    lab = paste(" =", rho)

    tibble(SMC = expand_sorted(res_smc[[i]]$seg, n_true),
           MCMC = expand_sorted(res_mcmc[[i]]$seg, n_true),
           true = sort(pl_enum$seg[rs_idx])) %>%
        slice_sample(n=5000) %>%
        pivot_longer(SMC:MCMC, names_to="alg", values_to="est") %>%
        ggplot(aes(true, est, color=alg)) +
        geom_abline(slope=1, color="red") +
        geom_point(size=0.4, alpha=0.2) +
        scale_x_continuous("True dissimilarity") +
        scale_y_continuous("Sampled dissimilarity") +
        scale_color_manual(values=c(PAL[2], "#444444")) +
        labs(title=rlang::expr(rho * !!lab), color="Algorithm") +
        guides(color=guide_legend(override.aes=list(alpha=1, size=4))) +
        theme_bw(base_family="Times", base_size=10)
})

p_valid = wrap_plots(plots) + guide_area() + plot_layout(guides="collect")

ggsave(here("figures/fl50_valid.pdf"), plot=p_valid, width=7.5, height=5)


# plan-space validation ----------
# Turn a plan's vector of district labels into a single string key, e.g.
# c(1, 2, 2) -> "122", so whole plans can be compared with match(). Keys are
# only unambiguous while every label is a single character.
to_lbl <- function(x) {
    paste0(x, collapse = "")
}
# Relabel every column (plan) of a label matrix so districts are numbered in
# order of first appearance. Two labelings of the same partition then become
# identical, whichever convention produced them.
canon_plans = function(m) {
    apply(m, 2, \(x) match(x, unique(x)))
}

# Enumerated plans within 1% population deviation, keyed by their canonical
# labeling. drop=FALSE keeps a matrix even if only one plan qualifies.
idx_bal = which(pl_enum$dev <= 0.01)
ids_enum = apply(canon_plans(enum_raw$plans[, idx_bal, drop=FALSE]), 2, to_lbl)

set.seed(1812)
pl_smc = redist_smc(set_pop_tol(fl50, 0.01), 50e3, compactness=1.0,
                    ncores=6, adapt_k_thresh=0.999, verbose=TRUE)

# Map each SMC draw to the index of the identical enumerated plan.
m_renumb = canon_plans(as.matrix(pl_smc))
smc_enum = match(apply(m_renumb, 2, to_lbl), ids_enum)
n_miss = sum(is.na(smc_enum))
if (n_miss > 0) {
    # tabulate() silently ignores NA, so unmatched draws would otherwise
    # vanish from the empirical distribution without notice.
    warning(n_miss, " sampled plans have no match among the enumerated ",
            "balanced plans", call.=FALSE)
}

# Target (w restricted to balanced plans) vs empirical SMC distribution,
# with one bin per balanced enumerated plan so the two vectors align.
tgt_dist = w[idx_bal] / sum(w[idx_bal])
samp_dist = tabulate(smc_enum, nbins=length(idx_bal))
samp_dist = samp_dist / sum(samp_dist)

# Total variation distance between the two distributions.
tv_dist = sum(abs(tgt_dist - samp_dist))/2
print(tv_dist)
